from config import *
from data import truncate_pad, token_transform_en
from transformer import TransformerEncoder, TransformerDecoder, EncoderDecoder



def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,
                    device, save_attention_weights=False):
    """Greedy sequence-to-sequence prediction.

    Args:
        net: EncoderDecoder model exposing `encoder` and `decoder`.
        src_sentence: Source (English) sentence to translate.
        src_vocab: Source vocabulary; maps a token to its index via `[]`.
        tgt_vocab: Target vocabulary; supports `[]` and `lookup_tokens`.
        num_steps: Fixed source length and maximum number of decode steps.
        device: Torch device to run inference on.
        save_attention_weights: If True, collect `net.decoder.attention_weights`
            after every decoding step.

    Returns:
        Tuple of (predicted tokens joined with no separator — suitable for
        Chinese output, list of saved attention weights).
    """
    # Set `net` to eval mode for inference (disables dropout etc.).
    net.eval()
    src_tokens = [src_vocab[tok]
                  for tok in token_transform_en(src_sentence.lower())]
    src_tokens.append(src_vocab['<eos>'])
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    # Truncate or pad the source sequence to exactly `num_steps` tokens.
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    output_seq, attention_weight_seq = [], []
    # No autograd graph is needed for inference; without this, every
    # decoding step would accumulate graph state and waste memory.
    with torch.no_grad():
        # Add the batch axis.
        enc_X = torch.unsqueeze(
            torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
        enc_outputs = net.encoder(enc_X, enc_valid_len)
        dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
        # Decoder input starts with <bos>; add the batch axis as well.
        dec_X = torch.unsqueeze(
            torch.tensor([tgt_vocab['<bos>']], dtype=torch.long,
                         device=device),
            dim=0)
        for _ in range(num_steps):
            Y, dec_state = net.decoder(dec_X, dec_state)
            # Greedy decoding: the argmax index is the predicted token and
            # becomes the next decoder input.
            dec_X = Y.argmax(dim=2)
            pred = dec_X.squeeze(dim=0).type(torch.int32).item()
            # Save attention weights if requested.
            if save_attention_weights:
                attention_weight_seq.append(net.decoder.attention_weights)
            # Once the end-of-sequence token is predicted, generation of
            # the output sequence is complete.
            if pred == tgt_vocab['<eos>']:
                break
            output_seq.append(pred)
    return ''.join(tgt_vocab.lookup_tokens(output_seq)), attention_weight_seq


if __name__ == "__main__":
    # Vocabularies were serialized with torch.save during preprocessing.
    src_vocab = torch.load(DATA_ROOT + "eng_vocab{}".format(num_examples))
    tgt_vocab = torch.load(DATA_ROOT + "zh_vocab{}".format(num_examples))

    # Checkpoints were saved on cuda:1; remap them onto cuda:0 at load time.
    encoder = torch.load(
        MODEL_ROOT + "trans_encoder{}.mdl".format(num_examples),
        map_location={'cuda:1': 'cuda:0'})
    decoder = torch.load(
        MODEL_ROOT + "trans_decoder{}.mdl".format(num_examples),
        map_location={'cuda:1': 'cuda:0'})

    net = EncoderDecoder(encoder, decoder)
    # Interactive loop: read one sentence per line, print its translation.
    # Exit cleanly on EOF (Ctrl-D) or Ctrl-C instead of a traceback.
    while True:
        try:
            sentence = input()
        except (EOFError, KeyboardInterrupt):
            break
        out, _ = predict_seq2seq(net, sentence, src_vocab, tgt_vocab,
                                 num_steps, device)
        print(out)

